{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Semantic Kernel Tool Use Example\n",
    "\n",
    "This notebook walks through a Semantic Kernel agent that uses Azure AI Search for Retrieval-Augmented Generation (RAG). The example builds an AI agent that retrieves travel documents from an Azure AI Search index, augments each user query with the retrieved text, and streams travel answers grounded in that context."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initializing the Environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Importing Packages\n",
    "The following code imports the necessary packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from azure.core.credentials import AzureKeyCredential\n",
    "from azure.search.documents import SearchClient\n",
    "from azure.search.documents.indexes import SearchIndexClient\n",
    "from azure.search.documents.indexes.models import SearchIndex, SimpleField, SearchFieldDataType, SearchableField\n",
    "\n",
    "from openai import AsyncOpenAI\n",
    "\n",
    "from semantic_kernel.kernel import Kernel\n",
    "from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion\n",
    "from semantic_kernel.contents.chat_history import ChatHistory\n",
    "from semantic_kernel.functions import kernel_function\n",
    "from semantic_kernel.functions.kernel_arguments import KernelArguments\n",
    "from semantic_kernel.connectors.ai import FunctionChoiceBehavior\n",
    "from semantic_kernel.contents.function_call_content import FunctionCallContent\n",
    "from semantic_kernel.contents.function_result_content import FunctionResultContent\n",
    "from semantic_kernel.agents import ChatCompletionAgent"
   ]
  },
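  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating the Semantic Kernel and AI Service\n",
    "\n",
    "A Semantic Kernel instance is created and configured with an asynchronous OpenAI chat completion service. The underlying `AsyncOpenAI` client reads its API key from the `GITHUB_TOKEN` environment variable and targets the `https://models.inference.ai.azure.com/` endpoint. The service is registered with the kernel under the service ID `agent`, which later cells use to look it up."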
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the asynchronous OpenAI client\n",
    "client = AsyncOpenAI(\n",
    "    api_key=os.environ[\"GITHUB_TOKEN\"],\n",
    "    base_url=\"https://models.inference.ai.azure.com/\"\n",
    ")\n",
    "\n",
    "# Create a Semantic Kernel instance and add an OpenAI chat completion service.\n",
    "kernel = Kernel()\n",
    "chat_completion_service = OpenAIChatCompletion(\n",
    "    ai_model_id=\"gpt-4o-mini\",\n",
    "    async_client=client,\n",
    "    service_id=\"agent\",\n",
    ")\n",
    "kernel.add_service(chat_completion_service)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the Prompt Plugin\n",
    "\n",
    "`PromptPlugin` is a native plugin that defines a single function, `build_augmented_prompt`, which combines the retrieved context and the user query into one prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "KernelPlugin(name='promptPlugin', description=None, functions={})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
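   ],
   "source": [
    "class PromptPlugin:\n",
    "    # @staticmethod is outermost so the @kernel_function metadata is attached to the underlying function.\n",
    "    @staticmethod\n",
    "    @kernel_function(\n",
    "        name=\"build_augmented_prompt\",\n",
    "        description=\"Build an augmented prompt using retrieval context.\"\n",
    "    )\n",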
    "    def build_augmented_prompt(query: str, retrieval_context: str) -> str:\n",
    "        return (\n",
    "            f\"Retrieved Context:\\n{retrieval_context}\\n\\n\"\n",
    "            f\"User Query: {query}\\n\\n\"\n",
    "            \"Based ONLY on the above context, please provide your answer.\"\n",
    "        )\n",
    "\n",
    "# Register the plugin with the kernel.\n",
    "kernel.add_plugin(PromptPlugin(), plugin_name=\"promptPlugin\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vector Database Initialization\n",
    "\n",
    "We create an Azure AI Search index named `travel-documents` and upload five sample documents to it. The agent retrieves from this index to ground its responses in the stored content."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<azure.search.documents._generated.models._models_py3.IndexingResult at 0x20cda33b080>,\n",
       " <azure.search.documents._generated.models._models_py3.IndexingResult at 0x20cdbdc1f40>,\n",
       " <azure.search.documents._generated.models._models_py3.IndexingResult at 0x20cdbdc1d00>,\n",
       " <azure.search.documents._generated.models._models_py3.IndexingResult at 0x20cdbdc2690>,\n",
       " <azure.search.documents._generated.models._models_py3.IndexingResult at 0x20cdbdc2450>]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Configure the Azure AI Search endpoint, API key, and index name.\n",
    "# os.environ[...] raises KeyError right away if a variable is missing.\n",
    "search_service_endpoint = os.environ[\"AZURE_SEARCH_SERVICE_ENDPOINT\"]\n",
    "search_api_key = os.environ[\"AZURE_SEARCH_API_KEY\"]\n",
    "index_name = \"travel-documents\"\n",
    "\n",
    "search_client = SearchClient(\n",
    "    endpoint=search_service_endpoint,\n",
    "    index_name=index_name,\n",
    "    credential=AzureKeyCredential(search_api_key)\n",
    ")\n",
    "\n",
    "index_client = SearchIndexClient(\n",
    "    endpoint=search_service_endpoint,\n",
    "    credential=AzureKeyCredential(search_api_key)\n",
    ")\n",
    "\n",
    "# Define the index schema\n",
    "fields = [\n",
    "    SimpleField(name=\"id\", type=SearchFieldDataType.String, key=True),\n",
    "    SearchableField(name=\"content\", type=SearchFieldDataType.String)\n",
    "]\n",
    "\n",
    "index = SearchIndex(name=index_name, fields=fields)\n",
    "\n",
    "# Create the index, or update it if it already exists, so the cell can be re-run\n",
    "index_client.create_or_update_index(index)\n",
    "\n",
    "# Sample documents\n",
    "documents = [\n",
    "    {\"id\": \"1\", \"content\": \"Contoso Travel offers luxury vacation packages to exotic destinations worldwide.\"},\n",
    "    {\"id\": \"2\", \"content\": \"Our premium travel services include personalized itinerary planning and 24/7 concierge support.\"},\n",
    "    {\"id\": \"3\", \"content\": \"Contoso's travel insurance covers medical emergencies, trip cancellations, and lost baggage.\"},\n",
    "    {\"id\": \"4\", \"content\": \"Popular destinations include the Maldives, Swiss Alps, and African safaris.\"},\n",
    "    {\"id\": \"5\", \"content\": \"Contoso Travel provides exclusive access to boutique hotels and private guided tours.\"}\n",
    "]\n",
    "\n",
    "# Add documents to the index\n",
    "search_client.upload_documents(documents)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A helper function `get_retrieval_context` is defined to query the index and return the top two relevant documents based on the user query:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_retrieval_context(query: str) -> str:\n",
    "    # Full-text search over the index; keep only the two highest-scoring hits.\n",
    "    results = search_client.search(search_text=query, top=2)\n",
    "    context_strings = [f\"Document: {result['content']}\" for result in results]\n",
    "    return \"\\n\\n\".join(context_strings) if context_strings else \"No results found\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting the Function Choice Behavior\n",
    "\n",
    "Semantic Kernel lets you control whether and how the agent chooses functions to call. This is configured with the `FunctionChoiceBehavior` class.\n",
    "\n",
    "The code below sets it to `Auto`, which lets the agent choose among the available functions or call none of them.\n",
    "\n",
    "It can also be set to:\n",
    "\n",
    "- `FunctionChoiceBehavior.Required`: requires the agent to choose at least one function.\n",
    "- `FunctionChoiceBehavior.NoneInvoke`: instructs the agent not to choose any function (useful for testing)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "settings = kernel.get_prompt_execution_settings_from_service_id(\"agent\")\n",
    "settings.function_choice_behavior = FunctionChoiceBehavior.Auto()\n",
    "arguments = KernelArguments(settings=settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "AGENT_NAME = \"TravelAgent\"\n",
    "AGENT_INSTRUCTIONS = (\n",
    "    \"Answer travel queries using the provided context. If context is provided, do not say 'I have no context for that.'\"\n",
    ")\n",
    "agent = ChatCompletionAgent(\n",
    "    service_id=\"agent\",\n",
    "    kernel=kernel,\n",
    "    name=AGENT_NAME,\n",
    "    instructions=AGENT_INSTRUCTIONS,\n",
    "    arguments=arguments,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper function `get_augmented_prompt` retrieves context for the query and builds the augmented prompt. It calls the plugin's static method directly instead of leaving the call to the agent's function choice:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "async def get_augmented_prompt(query: str) -> str:\n",
    "    retrieval_context = get_retrieval_context(query)\n",
    "    return PromptPlugin.build_augmented_prompt(query, retrieval_context)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Running the Agent with Streaming Chat History\n",
    "The main asynchronous function runs each user input as a separate conversation. For each input, it builds the augmented prompt, adds it to the chat history as a system message so that the agent sees the retrieval context, adds the user message, and then invokes the agent with streaming. The output is printed as it streams in, and the chat history is reset before the next query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# User: 'Can you explain Contoso's travel insurance coverage?'\n",
      "# Assistant - TravelAgent: 'Contoso's travel insurance covers medical emergencies, trip cancellations, and lost baggage. This means that if you encounter a medical issue while traveling, have to cancel your trip for unforeseen reasons, or lose your luggage, you are protected under their insurance plan.\n",
      "============================================================\n",
      "\n",
      "# User: 'What is Neural Network?'\n",
      "# Assistant - TravelAgent: 'I have no context for that.\n",
      "============================================================\n",
      "\n"
     ]
    }
   ],
   "source": [
    "async def main():\n",
    "    # Create a chat history.\n",
    "    chat_history = ChatHistory()\n",
    "    \n",
    "    user_inputs = [\n",
    "        \"Can you explain Contoso's travel insurance coverage?\", # Retrieval context available.\n",
    "        \"What is Neural Network?\" # No retrieval context available.\n",
    "    ]\n",
    "    \n",
    "    for user_input in user_inputs:\n",
    "        # Obtain the augmented prompt.\n",
    "        augmented_prompt = await get_augmented_prompt(user_input)\n",
    "        # Add the augmented prompt as a system message so the agent sees the retrieval context.\n",
    "        chat_history.add_system_message(augmented_prompt)\n",
    "        # Also add the user message.\n",
    "        chat_history.add_user_message(user_input)\n",
    "        \n",
    "        print(f\"# User: '{user_input}'\")\n",
    "        \n",
    "        agent_name: str | None = None\n",
    "        print(\"# Assistant - \", end=\"\")\n",
    "        async for content in agent.invoke_stream(chat_history):\n",
    "            if not agent_name:\n",
    "                agent_name = content.name\n",
    "                print(f\"{agent_name}: '\", end=\"\")\n",
    "            if (\n",
    "                not any(isinstance(item, (FunctionCallContent, FunctionResultContent))\n",
    "                        for item in content.items)\n",
    "                and content.content.strip()\n",
    "            ):\n",
    "                print(f\"{content.content}\", end=\"\", flush=True)\n",
    "        print(\"\\n\" + \"=\"*60 + \"\\n\")\n",
    "        \n",
    "        # Clear chat history for the next query.\n",
    "        chat_history = ChatHistory()\n",
    "\n",
    "await main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
